import numpy as np
import pandas as pd
import statsmodels.api as sm
import math
import sys
sys.path.insert(1,'/REDACTED/fairness/code/utilities')
import estimators as et


## Linear Estimator
def regEstimate(dataset, pbvarb = 'predicted_prob_black', outcome = 'audited'):
    """Fit an OLS regression of `outcome` on `pbvarb` with HC1 standard errors.

    Parameters
    ----------
    dataset : data object accepted by ``sm.OLS.from_formula`` that exposes
        columns named `pbvarb` and `outcome`.
    pbvarb : str
        Name of the single regressor in the formula.
    outcome : str
        Name of the dependent variable in the formula.

    Returns
    -------
    tuple
        ``(coef, se, black_aud_rate, non_black_aud_rate)`` where `coef` and
        `se` are the coefficient and HC1 standard error on `pbvarb`,
        `non_black_aud_rate` is the first fitted parameter and
        `black_aud_rate` is the sum of the first two fitted parameters
        (intercept and intercept + slope when `pbvarb` is a single numeric
        term -- the formula places the intercept first).

    Side effects
    ------------
    Prints the number of observations used in the fit.
    """
    model = sm.OLS.from_formula(outcome+' ~ '+pbvarb, data = dataset).fit(cov_type = 'HC1')
    coef = model.params[pbvarb]
    se = model.bse[pbvarb]
    # The parameters are label-indexed, so positional access must go through
    # .iloc; plain integer keys rely on a deprecated pandas fallback.
    first_param = model.params.iloc[0]
    second_param = model.params.iloc[1]
    black_aud_rate = first_param + second_param
    non_black_aud_rate = first_param
    print("number of obs :" + str(model.nobs))
    return coef, se, black_aud_rate, non_black_aud_rate

## Linear Estimator with controls
def regEstimate_control(dataset, pbvarb = 'predicted_prob_black', controls = '', outcome = 'audited'):
    """Fit an OLS regression of `outcome` on `pbvarb` plus optional controls.

    Parameters
    ----------
    dataset : data object accepted by ``sm.OLS.from_formula``.
    pbvarb : str
        Name of the regressor whose coefficient is reported.
    controls : str
        Right-hand-side formula fragment for the control terms
        (e.g. ``'age + C(state)'``). An empty or whitespace-only string means
        no controls.
    outcome : str
        Name of the dependent variable in the formula.

    Returns
    -------
    tuple
        ``(coef, se)``: the coefficient on `pbvarb` and its HC1 standard error.

    Side effects
    ------------
    Prints the number of observations used in the fit.
    """
    formula = outcome+' ~ '+pbvarb
    # Only append the controls when there are any: with the default
    # controls='' the formula would otherwise end in a dangling ' + ',
    # which is not a valid formula.
    if controls and controls.strip():
        formula = formula+' + '+controls
    model = sm.OLS.from_formula(formula, data = dataset).fit(cov_type = 'HC1')
    coef = model.params[pbvarb]
    se = model.bse[pbvarb]
    print("number of obs :"+ str(model.nobs))
    return coef, se

## Probabilistic Estimator 
def chenEstimate(dataset, pbvarb = 'predicted_prob_black', outcome = 'audited'):
    """Probability-weighted difference in mean `outcome` between the two groups.

    Each row contributes to the first group with weight ``dataset[pbvarb]``
    and to the second group with weight ``1 - dataset[pbvarb]``.

    Returns
    -------
    tuple
        ``(est, black_aud_rate, non_black_aud_rate)`` where each rate is the
        weighted mean of `outcome` under the corresponding weights and `est`
        is ``black_aud_rate - non_black_aud_rate``.
    """
    prob = dataset[pbvarb]
    result = dataset[outcome]
    complement = 1 - prob

    black_aud_rate = (prob * result).sum() / prob.sum()
    non_black_aud_rate = (complement * result).sum() / complement.sum()

    return black_aud_rate - non_black_aud_rate, black_aud_rate, non_black_aud_rate

## Get standard errors
def getSEs(dataset,pbvarb='predicted_prob_black',outcome='audited', seReg = 0.0):
    """Return the supplied regression SE together with its rescaled counterpart.

    The caller supplies `seReg`; it is multiplied by
    ``getSEMultiplier(dataset, pbvarb)`` to give the second value. `outcome`
    is accepted but not used by this function.

    Returns
    -------
    tuple
        ``(seReg, seChen)`` with ``seChen = seReg * multiplier``.
    """
    multiplier = getSEMultiplier(dataset, pbvarb)
    return seReg, seReg * multiplier

def getSEMultiplier(dataset,pbvarb='predicted_prob_black'):
    """Return ``sqrt(var(p) / (mean(p) * (1 - mean(p))))`` for ``p = dataset[pbvarb]``.

    The variance and mean are whatever ``.var()`` and ``.mean()`` return for
    the selected column.
    """
    probs = dataset[pbvarb]
    mean_prob = probs.mean()
    ratio = probs.var() / (mean_prob * (1 - mean_prob))
    return np.sqrt(ratio)

### OP AUDIT
# Load the individual-level BISG file and keep only rows that have a
# predicted probability.
dataBISG = pd.read_csv("/REDACTED/data/final/individualBISG2014_full_final.csv")
dataBISG = dataBISG[dataBISG.predicted_prob_black.notna()]

# NC rows only, reduced to the merge key and the probability (renamed to
# 'pprob_black' for the later merge). Selecting and renaming in one step
# avoids assigning a new column into a filtered slice, which raises
# SettingWithCopyWarning.
nc_BISG = (
    dataBISG.loc[dataBISG['state'] == "NC", ['taxpayer_id', 'predicted_prob_black']]
    .rename(columns={'predicted_prob_black': 'pprob_black'})
)

##NC data
ncdata = pd.read_csv("/REDACTED/NC_Analysis_Dataset_2014_tplevel.csv")

# Stata file providing the 'black_prob' and 'uswgt' columns used below.
nc_weight = pd.read_stata('/REDACTED/BIFSG/nc_rewt_2014_coded_output_v4_December_ncsamp.dta')

#switch weights here?
##########################################################
#### Table A.4
##########################################################

# merge data together and clean
# Keep rows with a non-missing black_ind, then attach the weight columns and
# the BISG probability by taxpayer_id (default pd.merge join).
ncdata = ncdata[ncdata.black_ind.notna()]
nc_weight = nc_weight[['taxpayer_id', 'black_prob', 'uswgt']]

ncdata = pd.merge(ncdata, nc_weight, on='taxpayer_id')
ncdata = pd.merge(ncdata, nc_BISG, on='taxpayer_id')

# Discard the probability columns that came with the inputs and use the BISG
# value (merged in as 'pprob_black') as 'predicted_prob_black' instead.
ncdata = ncdata.drop(columns=['predicted_prob_black', 'black_prob'])
ncdata = ncdata.rename(columns={'pprob_black': 'predicted_prob_black'})
ncdata['p_black_rd'] = ncdata['predicted_prob_black'].round(2)

ncdata['unitwt'] = 1

# Flag audited rows whose audit_source_code string contains none of these
# substrings (code 80 or 91 at the start or end of a bracketed list).
# NOTE(review): a code in the middle of a list of three or more (' 80 ')
# would not match any marker -- confirm against the data format.
research_markers = ('[80]', ' 80]', '[80 ', '[91]', ' 91]', '[91 ')
ncdata['aud_no_research_audits'] = [
    1 if (not any(marker in code for marker in research_markers)) and audited == 1 else 0
    for code, audited in zip(ncdata.audit_source_code.astype(str), ncdata.audited)
]

# Make sure every column used by the estimators is numeric.
for col in ('aud_no_research_audits', 'predicted_prob_black', 'p_black_rd',
            'uswgt', 'unitwt', 'black_ind'):
    ncdata[col] = pd.to_numeric(ncdata[col])


# loop over different subsets and calculate covariance conditions
# One output column per (sample, weighting) pair; 'Metrics' holds the row labels.
varlist = ['Metrics', 'Full Pop Unweighted', 'Full Pop Weighted', 'EITC Unweighted', 'EITC Weighted', 'Non-EITC Unweighted', 'Non-EITC Weighted']
output = pd.DataFrame(columns=varlist)
output['Metrics'] = ['E_cov_cond_b', 'E_cov_cond_b_se', 'E_cov_cond_B', 'E_cov_cond_B_se', 'nc_weights']
# Row filter for each entry of varlist (None for the label column).
filters = [None, ~ncdata.taxpayer_id.isna(), ~ncdata.taxpayer_id.isna(), ncdata.eic_ind==1, ncdata.eic_ind==1, ncdata.eic_ind==0, ncdata.eic_ind==0]

# (covarb, conditioningvarb) pairs, in the order of the metric rows above.
covariance_specs = [('black_ind', 'p_black_rd'), ('predicted_prob_black', 'black_ind')]

# Sample sizes are collected from the subsets actually passed to the
# estimator so the N row cannot drift from the filters.
n_row = ["N"]

for i, (label, row_filter) in enumerate(zip(varlist, filters)):
    print(i)
    if label == 'Metrics':
        continue

    dat = ncdata[row_filter]
    # Odd positions in varlist are the unweighted columns, even positions the
    # weighted ones.
    weighted = i % 2 == 0
    weightvarb = 'uswgt' if weighted else 'unitwt'

    column = []
    for covarb, conditioningvarb in covariance_specs:
        est,se,tstat,pval,_ = et.computeExpectedCovariance(dat, weightvarb=weightvarb, outcomevarb='aud_no_research_audits', covarb=covarb, conditioningvarb=conditioningvarb)
        # Estimates and standard errors are reported scaled by 10,000.
        column.append(est*10000)
        column.append(se*10000)
    # 'nc_weights' row: 1 for weighted columns, 0 for unweighted.
    column.append(1 if weighted else 0)
    output[label] = column
    n_row.append(len(dat))

# add new row for N
output.loc[len(output)] = n_row

# write out results
output.to_latex('/REDACTED/residual_covariance_terms_SE_NC_post_update.tex', index=False)